import os

import numpy as np
import matplotlib.pyplot as plt

from util import *
from alg_deterministic import *


# Hand-tuned hyper-parameters. ``L`` is passed to PAGD, ZO-NCF-GD and RSPI and
# ``rho`` to ZO-NCF-GD; the ``*_zpsgd`` pair is passed to ZPSGD only.
# NOTE(review): presumably gradient- / Hessian-Lipschitz estimates -- confirm
# against alg_deterministic.
L = 50
rho = 1
L_zpsgd = 1e3
rho_zpsgd = 1e-2

# Load the w1a dataset and keep only its first sample. The context manager
# closes the underlying .npz file once the arrays have been read into memory.
with np.load('data_w1a.npz') as data_w1a:
    x, y = data_w1a['x'], data_w1a['y']
x = np.array(x[:1])
y = np.array(y[:1])

# ``lamda`` is spelled this way because ``lambda`` is a reserved word.
lamda, alpha = 1, 1
f = construct_f_deterministic(x, y, lamda, alpha)

# Iteration budgets; ZPSGD is given 200x the base budget.
iters = 1000
iters_zpsgd = iters * 200
iters_rspi = 1000
batch_size = len(x)  # a single batch covers every (kept) sample
# Common starting point: the zero vector, one entry per feature column.
w = list(np.zeros(len(x[0])))
print(w)

# Seed NumPy's global RNG once for reproducibility (assuming the algorithms draw
# from ``np.random``). NOTE: the seed is shared by all four runs, so each run's
# random draws depend on how many draws the earlier runs consumed.
np.random.seed(10)

# Every algorithm gets its own copy of the starting point. ``w`` is a mutable
# list, so this prevents an algorithm that updates its argument in place from
# silently changing where the following experiments start.
# ``np.savez`` appends the ``.npz`` extension to each file name.
pagd_complexity, pagd_values = experiment_pagd(f, list(w), iters, L)
np.savez('pagd', pagd_complexity=pagd_complexity, pagd_values=pagd_values)

zo_gd_ncf_complexity, zo_gd_ncf_vals = experiment_zo_ncf_gd(f, list(w), iters, L, rho)
np.savez('zo_gd_ncf', zo_gd_ncf_complexity=zo_gd_ncf_complexity, zo_gd_ncf_vals=zo_gd_ncf_vals)

zpsgd_complexity, zpsgd_vals = zpsgd(f, list(w), iters_zpsgd, batch_size, rho_zpsgd, L_zpsgd)
np.savez('zpsgd', zpsgd_complexity=zpsgd_complexity, zpsgd_vals=zpsgd_vals)

rspi_complexity, rspi_vals = rspi(f, list(w), iters_rspi, L, sigma_1=5e-4, sigma_2=1e-4, T_sigma_1=20, ratio=0.95)
np.savez('rspi', rspi_complexity=rspi_complexity, rspi_vals=rspi_vals)


# Reload each run's results from the .npz archives written above. Indexing an
# NpzFile reads the whole array into memory, so the arrays stay valid after the
# ``with`` block closes the file handle.
with np.load('pagd.npz') as data_pagd:
    pagd_complexity = data_pagd['pagd_complexity']
    pagd_values = data_pagd['pagd_values']

with np.load('zo_gd_ncf.npz') as data_zo_gd_ncf:
    zo_gd_ncf_complexity = data_zo_gd_ncf['zo_gd_ncf_complexity']
    zo_gd_ncf_vals = data_zo_gd_ncf['zo_gd_ncf_vals']

with np.load('zpsgd.npz') as data_zpsgd:
    zpsgd_complexity = data_zpsgd['zpsgd_complexity']
    zpsgd_vals = data_zpsgd['zpsgd_vals']

with np.load('rspi.npz') as data_rspi:
    rspi_complexity = data_rspi['rspi_complexity']
    rspi_vals = data_rspi['rspi_vals']


# Plot the objective value against the number of function queries per method.
plt.rcParams.update({'font.size': 14})
plt.figure(figsize=(8, 6))

# Only the first 500 PAGD points are drawn.
# NOTE(review): presumably to keep the x-axis range comparable with the other
# methods -- confirm.
plt.plot(pagd_complexity[:500], pagd_values[:500], label='PAGD')
plt.plot(zo_gd_ncf_complexity, zo_gd_ncf_vals, label='ZO-NCF-GD')
plt.plot(zpsgd_complexity, zpsgd_vals, label='ZPSGD')
plt.plot(rspi_complexity, rspi_vals, label='RSPI')

plt.xlabel('# of Function Query')
plt.ylabel('Objective Function')
plt.legend()
# savefig does not create missing directories, so ensure the output folder exists.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/least_square_deterministic.pdf', bbox_inches='tight')
plt.show()